import os

from pyspark.sql import SparkSession
from pyspark.sql.functions import input_file_name, regexp_extract, round
from pyspark.sql.window import Window
from pyspark.sql.functions import avg, col
from pyspark.sql.types import StructType, StructField, DateType, DecimalType, StringType, IntegerType

if __name__ == "__main__":
    # Local Spark session using 3 worker threads for this sample job.
    spark = SparkSession \
        .builder \
        .master("local[3]") \
        .appName("sample2") \
        .getOrCreate()

    # Read all fund-price CSVs and derive a "fundid" column from each row's
    # source file name (e.g. FundPrice123.csv -> fundid "123").
    csvDF = spark.read \
        .format("csv") \
        .option("header", "true") \
        .load("data/fund/FundPrice*.csv") \
        .withColumn("fundid", regexp_extract(input_file_name(), r"FundPrice(\d+)\.csv", 1))

    # Trailing 5-row window (current row plus the 4 preceding) per fund.
    # NOTE(review): with header-only CSV input the "date" column is a string,
    # so the ordering is lexicographic -- correct only for ISO-formatted
    # dates (yyyy-MM-dd); confirm the input format. Spark resolves "date"
    # against the "Date" header because column resolution is
    # case-insensitive by default.
    window_spec = Window.partitionBy("fundid").orderBy("date").rowsBetween(-4, 0)

    # 5-day moving average of the accumulative price, rounded to 3 decimals.
    # (pyspark.sql.functions.round, not the builtin -- it is imported above.)
    csvDF = csvDF.withColumn("avgAccumulatePrice", round(avg(col("accumulativePrice")).over(window_spec), 3))

    print(csvDF.count())
    csvDF.show()
    print(csvDF.schema.json())

    # Cast each column to its target type before writing. A Parquet writer
    # takes its schema from the DataFrame itself, so these explicit casts --
    # not a separate StructType -- are what determine the output schema.
    # (An earlier, never-applied StructType definition was removed here.)
    csvDF = csvDF.select(
        col("Date").cast(DateType()),
        col("fundid"),
        col("accumulativePrice").cast(DecimalType(10, 3)),
        col("avgAccumulatePrice").cast(DecimalType(10, 3)),
        col("isbonus").cast(IntegerType()),
        col("isCash").cast(IntegerType()),
        col("price").cast(DecimalType(10, 3)),
        col("bonusPerStock").cast(DecimalType(10, 3)),
    )

    # coalesce(1) collapses the output to a single Parquet part file.
    csvDF.coalesce(1).write \
        .format("parquet") \
        .mode("overwrite") \
        .option("path", "data/fund/parquet/") \
        .save()

    # Read the data back to verify the written contents and schema.
    parquetDF = spark.read \
        .format("parquet") \
        .load("data/fund/parquet/*.parquet")
    parquetDF.show(5)
    print(parquetDF.schema.json())

    spark.stop()



